''' General RACE sketch in which each ACE is implemented by rehashing into a CMS.
       We keep enough flexibility to also use the original RACE (with arrays), with or without rehashing.

parameters :
  - L : number of hash tables for LSH, i.e. in the context of RACE, the number of repetitions
  - ACE:
    - R : Range of Array
    - K : number of repetitions
    - rehash : assert(rehash => K=1)
    - topK :
'''

from Hash import *
import pdb
from scipy.sparse import csr_matrix

# DONT USE NUMPY. JUST USE TORCH

class RaceGen:
  ''' RACE sketch holding one counter sketch per (class, repetition) pair.

  Hash codes for a batch come from ``self.hashfunction``. Each repetition
  consumes its own slice of ``power`` consecutive hash columns and inserts the
  rows belonging to class ``c`` into ``self.sketch_memory[c][repetition]``.

  Keys read from ``params``:
    - "range"        : stored as ``self.range``; upper bound for
                       ``range_hf ** power`` when rehashing is disabled
    - "repetitions"  : number of sketches kept per class
    - "power"        : number of hash columns consumed by each repetition
    - "num_classes"  : number of classes; labels are expected in [0, num_classes)
    - "max_coord", "min_coord" : bounds used by get_bounding_equations
    - "lsh_function" : forwarded to ``HashFunction.get``
    - "sketch_factory" (optional) : zero-argument callable returning a fresh
                       sketch object exposing ``insert(keys, values)``
    - "rehash"       (optional, default True) : when False the hash codes are
                       used without rehashing, so ``range_hf ** power`` must
                       not exceed ``range``
    - "range_hf"     (optional) : base used to build ``self.offset_arr``

  NOTE(review): the previous revision left the per-repetition sketch
  construction unfinished and read ``self.range_hf`` / ``self.rehash`` without
  ever assigning them. The "sketch_factory", "rehash" and "range_hf" keys are
  the reviewer's reconstruction - confirm the intended sketch type and the
  real source of these two values.
  '''

  def __init__(self, params):
    ''' Build the hash function and one empty sketch per (class, repetition).

    Raises:
      NotImplementedError : no "sketch_factory" was supplied and
                            ``_make_sketch`` is not overridden.
      ValueError          : rehashing is disabled and ``range_hf ** power``
                            is unknown or larger than ``range``.
    '''
    self.range = params["range"]
    self.repetitions = params["repetitions"]
    self.power = params["power"]
    self.num_classes = params["num_classes"]
    self.max_coord = params["max_coord"]
    self.min_coord = params["min_coord"]
    self.hashfunction = HashFunction.get(params["lsh_function"], num_hashes=self.power * self.repetitions)
    self.params = params

    # sketch_memory[c][rep] is the sketch of class c for repetition rep.
    self._sketch_factory = params.get("sketch_factory")
    self.sketch_memory = {
      c: [self._make_sketch(c, rep) for rep in range(self.repetitions)]
      for c in range(self.num_classes)
    }
    self.class_counts = np.zeros(self.num_classes)

    self.rehash = params.get("rehash", True)
    self.range_hf = params.get("range_hf")
    if self.range_hf is None:
      self.offset_arr = None
    else:
      # Successive powers of range_hf: 1, range_hf, range_hf**2, ...
      self.offset_arr = np.power(self.range_hf, np.arange(self.power))

    if not self.rehash:
      # Rehashing is not allowed, so the actual value should be within range.
      # An explicit raise is used because assert is stripped under `python -O`.
      if self.range_hf is None:
        raise ValueError("params['range_hf'] is required when rehash is disabled")
      if self.range_hf ** self.power > self.range:
        raise ValueError(
          "range_hf ** power = %s ** %s = %s exceeds range = %s"
          % (self.range_hf, self.power, self.range_hf ** self.power, self.range)
        )

  def _make_sketch(self, class_id, rep):
    ''' Return a new, empty sketch for class `class_id` and repetition `rep`.

    The default delegates to the zero-argument callable given as
    params["sketch_factory"]; subclasses may override this hook instead.
    '''
    if self._sketch_factory is None:
      raise NotImplementedError(
        "provide params['sketch_factory'] or override RaceGen._make_sketch"
      )
    return self._sketch_factory()

  def sketch(self, x, y):
    '''
      x : b x d
      y : b x 1 (or b) \\in [0,num_classes)

      Hashes the batch, then for every repetition inserts the hash rows of
      each class into that class's sketch with a value of 1 per row.
      `class_counts` grows by the number of examples of each class, counted
      once per call (not once per repetition).
    '''
    hashes = self.hashfunction.compute(x) # b x (power*repetitions)

    # Flatten the labels so each mask selects whole rows; a (b, 1) mask does
    # not match a (b, power) block when power > 1.
    labels = y.reshape(-1)
    masks = [labels == c for c in range(self.num_classes)]

    for c, mask in enumerate(masks):
      self.class_counts[c] += int(mask.sum())

    for rep in range(self.repetitions):
      hash_values = hashes[:, rep * self.power:(rep + 1) * self.power]
      for c, mask in enumerate(masks):
        examples_perclass = hash_values[mask]
        self.sketch_memory[c][rep].insert(
          examples_perclass, torch.ones((examples_perclass.shape[0], 1))
        )

  def get_dictionary(self):
    ''' Return the sketch state as a plain dict.

    Keys: "memory" (the sketch_memory mapping), "hashfunction" (the hash
    function's own get_dictionary()), "params" and "class_counts". The
    objects are returned by reference, not copied.
    '''
    return {
      "memory": self.sketch_memory,
      "hashfunction": self.hashfunction.get_dictionary(),
      "params": self.params,
      "class_counts": self.class_counts,
    }

  def get_hf_equations(self, hash_values, rep, chunk_size):
    ''' Return (W, b) exactly as produced by hashfunction.get_equations. '''
    W_heq, b_heq = self.hashfunction.get_equations(hash_values, rep, chunk_size)
    return W_heq, b_heq

  def get_bounding_equations(self):
    ''' Return (W, b) describing the coordinate bounding box.

    W stacks the d x d identity on top of its negation and b stacks
    max_coord repeated d times on top of -min_coord repeated d times, where
    d is ``self.hashfunction.d``. Presumably read as W @ x <= b, i.e.
    min_coord <= x_i <= max_coord - confirm against the consumer.
    '''
    d = self.hashfunction.d
    W_max = np.identity(d)
    b_max = np.ones(d) * self.max_coord
    W_min = np.identity(d) * -1
    b_min = np.ones(d) * self.min_coord * -1
    W_total = np.concatenate([W_max, W_min])
    b_total = np.concatenate([b_max, b_min])
    return W_total, b_total

  def get_equations(self, hash_values, rep, chunk_size):
    ''' Return the hash equations followed by the bounding-box equations.

    Rows are ordered as: hash-function rows, then the max_coord rows, then
    the min_coord rows.
    '''
    W_heq, b_heq = self.get_hf_equations(hash_values, rep, chunk_size)
    W_box, b_box = self.get_bounding_equations()
    W_total = np.concatenate([W_heq, W_box])
    b_total = np.concatenate([b_heq, b_box])
    return W_total, b_total

  def query(self, x, y):
    ''' return the K.D.E value for x w.r.t class y '''
    raise NotImplementedError
